{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "740cfb33",
   "metadata": {},
   "source": [
    "Reference: https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e515a275",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from llama_index.core import SQLDatabase\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from sqlalchemy import (\n",
    "    create_engine,\n",
    "    MetaData,\n",
    "    Table,\n",
    "    Column,\n",
    "    String,\n",
    "    Integer,\n",
    "    select,\n",
    ")\n",
    "from IPython.display import Markdown, display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "113ca4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schema registry that the table definitions attach to.\n",
    "metadata_obj = MetaData()\n",
    "\n",
    "# In-memory SQLite database, so nothing is written to disk.\n",
    "engine = create_engine(\"sqlite:///:memory:\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a24f4791",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the city table: one row per city, keyed by the city name.\n",
    "table_name = \"city_stats\"\n",
    "city_stats_columns = [\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    "]\n",
    "city_stats_table = Table(table_name, metadata_obj, *city_stats_columns)\n",
    "\n",
    "# Create every table registered on the metadata object in the database.\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "56dee6f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import insert\n",
    "\n",
    "# Wrap the engine for LlamaIndex, exposing only the city table.\n",
    "sql_database = SQLDatabase(engine, include_tables=[table_name])\n",
    "\n",
    "rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\n",
    "        \"city_name\": \"Chicago\",\n",
    "        \"population\": 2679000,\n",
    "        \"country\": \"United States\",\n",
    "    },\n",
    "    {\n",
    "        \"city_name\": \"New York\",\n",
    "        \"population\": 8258000,\n",
    "        \"country\": \"United States\",\n",
    "    },\n",
    "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
    "    {\"city_name\": \"Busan\", \"population\": 3334000, \"country\": \"South Korea\"},\n",
    "]\n",
    "\n",
    "# Insert all rows in one transaction (executemany): either every row is\n",
    "# committed or, if any insert fails, none are.\n",
    "# `city_name` is the primary key, so re-running this cell against the same\n",
    "# engine fails on duplicate keys; re-create the engine and table first.\n",
    "with engine.begin() as connection:\n",
    "    connection.execute(insert(city_stats_table), rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5397c08a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('New York', 8258000, 'United States'), ('Seoul', 9776000, 'South Korea'), ('Busan', 3334000, 'South Korea')]\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: read back every row that was just inserted.\n",
    "# Selecting the table itself returns all of its columns in declaration\n",
    "# order (city_name, population, country).\n",
    "stmt = select(city_stats_table)\n",
    "\n",
    "with engine.connect() as connection:\n",
    "    results = connection.execute(stmt).fetchall()\n",
    "\n",
    "# The rows are already fetched, so they can be printed after the\n",
    "# connection is closed.\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "901aa284",
   "metadata": {},
   "source": [
    "## Query Index\n",
    "\n",
    "First, run a raw SQL query directly against the table as a sanity check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3d1e1120",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('Busan',)\n",
      "('Chicago',)\n",
      "('New York',)\n",
      "('Seoul',)\n",
      "('Tokyo',)\n",
      "('Toronto',)\n"
     ]
    }
   ],
   "source": [
    "from sqlalchemy import text\n",
    "\n",
    "# Run a hand-written SQL statement straight against the engine and print\n",
    "# each result row as it is streamed back.\n",
    "city_name_query = text(\"SELECT city_name from city_stats\")\n",
    "\n",
    "with engine.connect() as connection:\n",
    "    for city_row in connection.execute(city_name_query):\n",
    "        print(city_row)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e3a1047",
   "metadata": {},
   "source": [
    "## Part 1: Text-to-SQL Query Engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "69c02ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import NLSQLTableQueryEngine\n",
    "\n",
    "# LLM handed to the text-to-SQL query engine. The API key and endpoint are\n",
    "# read from the same environment variables as the embedding model; an unset\n",
    "# variable yields None, leaving the library to apply its own defaults.\n",
    "# NOTE(review): model and temperature follow the referenced LlamaIndex demo;\n",
    "# change them to whatever your endpoint serves.\n",
    "llm = OpenAI(\n",
    "    temperature=0.1,\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
    "    api_base=os.getenv(\"OPENAI_BASE_URL\"),\n",
    ")\n",
    "\n",
    "# Restrict the engine to the single table it is allowed to query.\n",
    "query_engine = NLSQLTableQueryEngine(\n",
    "    sql_database=sql_database,\n",
    "    tables=[\"city_stats\"],\n",
    "    llm=llm,\n",
    ")\n",
    "\n",
    "query_str = \"Which city has the highest population?\"\n",
    "response = query_engine.query(query_str)\n",
    "\n",
    "# Render the answer so the result of this section is visible in the notebook.\n",
    "display(Markdown(f\"{response}\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "514d4e47",
   "metadata": {},
   "source": [
    "## Part 2: Query-Time Retrieval of Tables for Text-to-SQL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7a8ba1e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.indices.struct_store.sql_query import (\n",
    "    SQLTableRetrieverQueryEngine,\n",
    ")\n",
    "from llama_index.core.objects import (\n",
    "    SQLTableNodeMapping,\n",
    "    ObjectIndex,\n",
    "    SQLTableSchema,\n",
    ")\n",
    "from llama_index.core import VectorStoreIndex\n",
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "\n",
    "# Node mapping that the object index uses to store the table schema objects.\n",
    "table_node_mapping = SQLTableNodeMapping(sql_database)\n",
    "\n",
    "# One SQLTableSchema per table that should be retrievable at query time.\n",
    "table_schema_objs = [SQLTableSchema(table_name=\"city_stats\")]\n",
    "\n",
    "# Embedding model used to index the table schemas; the API key and endpoint\n",
    "# are read from the environment.\n",
    "embed_model = OpenAIEmbedding(\n",
    "    model=\"text-embedding-3-small\",\n",
    "    api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
    "    api_base=os.getenv(\"OPENAI_BASE_URL\"),\n",
    ")\n",
    "\n",
    "obj_index = ObjectIndex.from_objects(\n",
    "    table_schema_objs,\n",
    "    table_node_mapping,\n",
    "    VectorStoreIndex,\n",
    "    embed_model=embed_model,\n",
    ")\n",
    "\n",
    "# Retrieve only the single best-matching table schema for each question.\n",
    "# NOTE(review): no llm is passed to this engine, so it presumably falls back\n",
    "# to LlamaIndex's globally configured default LLM -- confirm that is intended.\n",
    "table_retriever = obj_index.as_retriever(similarity_top_k=1)\n",
    "query_engine = SQLTableRetrieverQueryEngine(sql_database, table_retriever)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8637d6fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the question through the retriever-backed engine and render the\n",
    "# answer as Markdown.\n",
    "question = \"Which city has the highest population?\"\n",
    "response = query_engine.query(question)\n",
    "display(Markdown(str(response)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
